"""
Phase 1: Data Assembly for Fiscal Dominance Project
=====================================================
Build fiscal dominance panel from multilateral project data + WEO fiscal variables.

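Source: multilateral/followup/data/processed/full_panel.csv + multilateral/data/raw/weo_2025_1.csv
        (+ multilateral/data/processed/future_oadr.csv if the panel has no oadr_plus20 column)
Output: fiscal_dominance/data/processed/fiscal_panel.csv
        fiscal_dominance/output/tables/phase1_summary_stats.csv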
"""

import pandas as pd
import numpy as np
from pathlib import Path

# Filesystem layout: every location is derived from the project root.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
MULTILATERAL_DIR = PROJECT_DIR.joinpath("multilateral")
FD_DIR = PROJECT_DIR.joinpath("fiscal_dominance")
PROCESSED_DIR = FD_DIR.joinpath("data", "processed")
TABLE_DIR = FD_DIR.joinpath("output", "tables")

# Create both output folders up front (no error if they already exist).
PROCESSED_DIR.mkdir(parents=True, exist_ok=True)
TABLE_DIR.mkdir(parents=True, exist_ok=True)


def parse_weo_fiscal(weo_path):
    """
    Parse WEO raw CSV (tab-separated, UTF-16-LE) to extract fiscal variables.

    The file must provide an 'ISO' column, a 'WEO Subject Code' column and
    one column per year (header consisting of digits only). The entries
    'n/a', '--' and '...' are read as missing, and thousands separators
    (commas) are removed before values are converted to numbers.

    Returns DataFrame with columns: iso3, year, + one column per fiscal
    variable found in the file. Variables are outer-merged, so a
    country-year is kept as soon as one variable has a value for it.

    Raises ValueError if the file has no year columns or if none of the
    requested subject codes is present.
    """
    print("\n  Parsing WEO fiscal variables...")

    # WEO file is UTF-16-LE with tab separators
    weo = pd.read_csv(weo_path, sep='\t', encoding='utf-16-le', na_values=['n/a', '--', '...'])

    # Year columns are those whose header is all digits (after trimming)
    year_cols = [c for c in weo.columns if c.strip().isdigit()]
    if not year_cols:
        # Without this guard melt() would run with nothing to unpivot and
        # the function would quietly produce no data.
        raise ValueError("No year columns found in WEO data")

    # Fiscal variable codes to extract
    codes = {
        'GGXWDG_NGDP': 'govt_debt_gdp',           # Government gross debt / GDP
        'GGXWDN_NGDP': 'govt_net_debt_gdp',        # Government net debt / GDP
        'GGXONLB_NGDP': 'primary_bal_gdp',         # Primary balance / GDP
        'GGSB_NPGDP': 'structural_bal_gdp',        # Cyclically-adjusted balance / GDP
        'GGR_NGDP': 'govt_revenue_gdp',            # Government revenue / GDP
        'GGX_NGDP': 'govt_expenditure_gdp',        # Government expenditure / GDP
    }

    # Trim the subject codes once rather than on every pass through the loop
    subject_codes = weo['WEO Subject Code'].str.strip()

    frames = []
    for code, name in codes.items():
        subset = weo[subject_codes == code]
        if subset.empty:
            print(f"    WARNING: {code} ({name}) not found in WEO data")
            continue

        # Melt year columns to long format (one row per country-year)
        melted = subset.melt(
            id_vars=['ISO'], value_vars=year_cols,
            var_name='year_str', value_name=name
        )
        melted['iso3'] = melted['ISO'].str.strip()
        melted['year'] = melted['year_str'].str.strip().astype(int)

        # Clean numeric: remove commas, convert; anything unparseable -> NaN
        melted[name] = melted[name].astype(str).str.replace(',', '', regex=False)
        melted[name] = pd.to_numeric(melted[name], errors='coerce')

        melted = melted[['iso3', 'year', name]].dropna(subset=[name])
        frames.append(melted)
        print(f"    {code} -> {name}: {len(melted):,} obs")

    if not frames:
        raise ValueError("No fiscal variables extracted from WEO")

    # Merge all fiscal variables on country-year
    fiscal = frames[0]
    for f in frames[1:]:
        fiscal = fiscal.merge(f, on=['iso3', 'year'], how='outer')

    return fiscal


def _assign_rate_hierarchy(df):
    """Fill ``rate_fd`` and ``rate_source`` in place.

    Each row takes the first non-missing rate in priority order: 10-year
    government bond yield, then policy rate, then lending rate. Rows with
    none of the three keep ``rate_fd`` = NaN and ``rate_source`` = ''.
    """
    df['rate_fd'] = np.nan
    df['rate_source'] = ''
    for source in ('govt_bond_10y', 'policy_rate', 'lending_rate'):
        use = df['rate_fd'].isna() & df[source].notna()
        df.loc[use, 'rate_fd'] = df.loc[use, source]
        df.loc[use, 'rate_source'] = source


def _previous_year_value(df, column):
    """Return ``column`` from the same country exactly one year earlier.

    ``df`` must already be sorted by iso3 and year. The result is NaN when
    the preceding row of the country is not year-1, so a missing year in the
    panel never produces a multi-year "lag".
    """
    by_country = df.groupby('iso3')
    follows_previous_year = by_country['year'].diff() == 1
    return by_country[column].shift(1).where(follows_previous_year)


def _add_hp_output_gap(df, min_obs=10, lamb=6.25):
    """Add ``hp_output_gap`` (100 * HP-filter cycle of log GDP p.c. PPP) in place.

    Only countries with at least ``min_obs`` positive ``gdp_pc_ppp``
    observations are filtered. Returns the number of countries filtered.
    """
    from statsmodels.tsa.filters.hp_filter import hpfilter

    df['hp_output_gap'] = np.nan
    # "> 0" also drops missing values; non-positive levels have no logarithm.
    usable = df[df['gdp_pc_ppp'] > 0]

    n_countries = 0
    for iso3, cdf in usable.groupby('iso3', sort=False):
        if len(cdf) < min_obs:  # need reasonable time series
            continue
        cdf = cdf.sort_values('year')
        log_gdp = np.log(cdf['gdp_pc_ppp'].values)
        try:
            # lambda=6.25 for annual data (Ravn-Uhlig)
            cycle, _trend = hpfilter(log_gdp, lamb=lamb)
        except Exception as exc:
            # Best effort: skip the country, but say so instead of hiding it.
            print(f"    WARNING: HP filter failed for {iso3}: {exc}")
            continue
        # Express as % deviation from trend
        df.loc[cdf.index, 'hp_output_gap'] = cycle * 100
        n_countries += 1
    return n_countries


def _add_govt_exp_gap(df, window=5, min_periods=3):
    """Add ``govt_exp_gap`` in place: expenditure/GDP minus its centered rolling mean.

    ``df`` must be sorted by iso3 and year. The window runs over each
    country's non-missing observations (not calendar years); countries with
    fewer than ``window`` observations are left as NaN.
    """
    df['govt_exp_gap'] = np.nan
    observed = df.dropna(subset=['govt_expenditure_gdp'])
    for _, series in observed.groupby('iso3', sort=False)['govt_expenditure_gdp']:
        if len(series) >= window:
            trend = series.rolling(window=window, min_periods=min_periods, center=True).mean()
            df.loc[series.index, 'govt_exp_gap'] = series - trend


def _winsorize_in_place(df, columns, lower=0.01, upper=0.99, min_obs=50):
    """Clip each listed column to its pooled [lower, upper] quantiles.

    Columns that are absent or have ``min_obs`` or fewer non-missing values
    are left untouched.
    """
    for var in columns:
        if var not in df.columns or df[var].notna().sum() <= min_obs:
            continue
        p1 = df[var].quantile(lower)
        p99 = df[var].quantile(upper)
        n_clipped = ((df[var] < p1) | (df[var] > p99)).sum()
        df[var] = df[var].clip(lower=p1, upper=p99)
        print(f"  Winsorized {var}: [{p1:.1f}, {p99:.1f}], {n_clipped} obs clipped")


def _threshold_indicator(values, threshold):
    """Return 1.0 where values > threshold, 0.0 otherwise, NaN where values are missing."""
    return (values > threshold).astype(float).where(values.notna())


def main():
    """Assemble the fiscal dominance panel and write it to disk.

    Steps: load full_panel.csv, keep 1990-2024, merge the WEO fiscal
    variables, build the rate hierarchy and r-g differentials, compute lags
    and gaps, winsorize, add future OADR, indicators and spline terms, then
    save the selected columns.

    Side effects: writes fiscal_panel.csv to PROCESSED_DIR and
    phase1_summary_stats.csv to TABLE_DIR, and prints progress to stdout.

    Returns the saved panel as a DataFrame.
    """
    first_year, last_year = 1990, 2024

    print("=" * 70)
    print("PHASE 1: Data Assembly for Fiscal Dominance Project")
    print("=" * 70)

    # -----------------------------------------------------------------
    # 1. Load full_panel.csv
    # -----------------------------------------------------------------
    fp = pd.read_csv(MULTILATERAL_DIR / "followup" / "data" / "processed" / "full_panel.csv")
    print(f"\nLoaded full_panel: {fp.shape[0]:,} rows, {fp['iso3'].nunique()} countries, "
          f"{fp['year'].min()}-{fp['year'].max()}")

    # -----------------------------------------------------------------
    # 2. Filter to 1990-2024
    # -----------------------------------------------------------------
    df = fp[(fp['year'] >= first_year) & (fp['year'] <= last_year)].copy()
    print(f"After {first_year}-{last_year} filter: {len(df):,} rows, "
          f"{df['iso3'].nunique()} countries")

    # -----------------------------------------------------------------
    # 3. Parse WEO fiscal variables
    # -----------------------------------------------------------------
    weo_path = MULTILATERAL_DIR / "data" / "raw" / "weo_2025_1.csv"
    fiscal = parse_weo_fiscal(weo_path)
    fiscal = fiscal[(fiscal['year'] >= first_year) & (fiscal['year'] <= last_year)]
    print(f"\n  WEO fiscal data: {len(fiscal):,} obs, {fiscal['iso3'].nunique()} countries")

    # -----------------------------------------------------------------
    # 4. Merge fiscal variables onto panel
    # -----------------------------------------------------------------
    fiscal_cols = [c for c in fiscal.columns if c not in ['iso3', 'year']]
    df = df.merge(fiscal[['iso3', 'year'] + fiscal_cols], on=['iso3', 'year'], how='left')
    for col in fiscal_cols:
        print(f"  {col}: {df[col].notna().sum():,} obs after merge")

    # -----------------------------------------------------------------
    # 5. Rate variable hierarchy (same as japanification)
    # -----------------------------------------------------------------
    _assign_rate_hierarchy(df)

    print("\nRate variable hierarchy:")
    for src, cnt in df['rate_source'].value_counts().items():
        if src:
            print(f"  {src}: {cnt:,}")
    print(f"  missing: {(df['rate_source'] == '').sum():,}")

    # -----------------------------------------------------------------
    # 6. Compute r-g differentials
    # -----------------------------------------------------------------
    # Nominal rate - real growth (proxy; consistent with Blanchard 2019)
    df['r_minus_g'] = df['rate_fd'] - df['rgdp_growth']

    # Real rate - real growth (where possible)
    df['real_rate'] = df['rate_fd'] - df['inflation']
    df['r_minus_g_real'] = df['real_rate'] - df['rgdp_growth']

    print("\nr-g differential coverage:")
    print(f"  r_minus_g (nominal rate - real growth): {df['r_minus_g'].notna().sum():,}")
    print(f"  r_minus_g_real (real rate - real growth): {df['r_minus_g_real'].notna().sum():,}")

    # -----------------------------------------------------------------
    # 7. Compute delta_old_dep (speed of aging) and debt_change
    # -----------------------------------------------------------------
    # Sorting is required by the lag and rolling computations below.
    df = df.sort_values(['iso3', 'year'])

    if 'delta_old_dep' not in df.columns:
        df['delta_old_dep'] = df['old_dep'] - _previous_year_value(df, 'old_dep')

    df['debt_lag'] = _previous_year_value(df, 'govt_debt_gdp')
    df['debt_change'] = df['govt_debt_gdp'] - df['debt_lag']
    df['primary_bal_lag'] = _previous_year_value(df, 'primary_bal_gdp')

    print(f"\n  delta_old_dep: {df['delta_old_dep'].notna().sum():,} obs")
    print(f"  debt_lag: {df['debt_lag'].notna().sum():,} obs")
    print(f"  debt_change: {df['debt_change'].notna().sum():,} obs")

    # -----------------------------------------------------------------
    # 8. HP-filtered output gap (for Bohn test)
    # -----------------------------------------------------------------
    # WEO output_gap only covers ~27 countries. Compute HP-filtered gap
    # from log(GDP PPP) to get 180+ countries.
    print(f"\n  WEO output_gap coverage: {df['output_gap'].notna().sum():,} obs "
          f"({df.loc[df['output_gap'].notna(), 'iso3'].nunique()} countries)")

    hp_count = _add_hp_output_gap(df)
    print(f"  HP-filtered output_gap: {df['hp_output_gap'].notna().sum():,} obs "
          f"({hp_count} countries)")

    # Use HP gap as primary, WEO gap as fallback check
    df['output_gap_hp'] = df['hp_output_gap']

    # Government expenditure gap (deviation from trend)
    _add_govt_exp_gap(df)
    print(f"  govt_exp_gap: {df['govt_exp_gap'].notna().sum():,} obs")

    # -----------------------------------------------------------------
    # 9. Winsorize fiscal variables at p1/p99
    # -----------------------------------------------------------------
    # NOTE(review): debt_lag and primary_bal_lag were built above from the
    # un-winsorized series and are not clipped here - confirm that is intended.
    winsorize_vars = ['govt_debt_gdp', 'govt_net_debt_gdp', 'primary_bal_gdp',
                      'structural_bal_gdp', 'govt_revenue_gdp', 'govt_expenditure_gdp',
                      'r_minus_g', 'r_minus_g_real', 'debt_change']
    _winsorize_in_place(df, winsorize_vars)

    # -----------------------------------------------------------------
    # 10. Add future OADR
    # -----------------------------------------------------------------
    if 'oadr_plus20' not in df.columns:
        future = pd.read_csv(MULTILATERAL_DIR / "data" / "processed" / "future_oadr.csv")
        df = df.merge(future[['iso3', 'year', 'oadr_plus20']], on=['iso3', 'year'], how='left')
    print(f"\n  Future OADR (oadr_plus20): {df['oadr_plus20'].notna().sum():,} obs")

    # -----------------------------------------------------------------
    # 11. Indicators and spline terms for interactions
    # -----------------------------------------------------------------
    # High-debt indicator: debt above 60% of GDP, NaN where debt is missing
    df['high_debt'] = _threshold_indicator(df['govt_debt_gdp'], 60)

    # Income group: above the pooled median of GDP per capita. Missing GDP
    # stays NaN rather than being counted as low income.
    df['log_gdp_pc'] = np.log(df['gdp_pc_ppp'].clip(lower=100))
    median_gdp = df['gdp_pc_ppp'].median()
    df['high_income'] = _threshold_indicator(df['gdp_pc_ppp'], median_gdp)

    # OADR spline knots
    for knot in [15, 20, 25, 30]:
        df[f'oadr_above_{knot}'] = np.maximum(df['old_dep'] - knot, 0)

    # -----------------------------------------------------------------
    # 12. Select columns
    # -----------------------------------------------------------------
    out_cols = [
        'iso3', 'year',
        # Demographics - polynomials
        'Z_1', 'Z_2', 'Z_3',
        # Demographics - dependency ratios
        'youth_dep', 'old_dep', 'total_dep', 'working_age_share',
        'delta_old_dep', 'oadr_plus20',
        # Demographics - other
        'life_expectancy',
        # Fiscal variables (NEW from WEO)
        'govt_debt_gdp', 'govt_net_debt_gdp', 'primary_bal_gdp',
        'structural_bal_gdp', 'govt_revenue_gdp', 'govt_expenditure_gdp',
        # Lagged/differenced fiscal
        'debt_lag', 'debt_change', 'primary_bal_lag',
        # Rate and r-g
        'rate_fd', 'rate_source', 'real_rate', 'r_minus_g', 'r_minus_g_real',
        'govt_bond_10y', 'policy_rate', 'lending_rate',
        'real_bond_10y', 'real_bond_10y_diff',
        # Gaps
        'output_gap', 'output_gap_hp', 'hp_output_gap', 'govt_exp_gap',
        # Controls
        'fiscal_bal_gdp', 'kaopen', 'log_rel_opw', 'nfa_gdp_lag',
        'expected_growth', 'rgdp_growth', 'inflation',
        'gdp_pc_ppp', 'log_gdp_pc', 'ngdp_usd',
        # CA for cross-validation
        'ca_gdp',
        # Interactions (pre-computed)
        'Z_1_x_kaopen', 'Z_2_x_kaopen', 'Z_3_x_kaopen',
        # Indicators
        'high_debt', 'high_income',
        # OADR spline
        'oadr_above_15', 'oadr_above_20', 'oadr_above_25', 'oadr_above_30',
        # Pension (OECD subset)
        'pension_spending_gdp',
    ] + [f'd_n_{i}' for i in range(1, 18)]

    # Only keep columns that exist; dict.fromkeys drops repeats, keeps order
    out_cols = list(dict.fromkeys(c for c in out_cols if c in df.columns))

    panel = df[out_cols].copy()

    # -----------------------------------------------------------------
    # 13. Save
    # -----------------------------------------------------------------
    panel.to_csv(PROCESSED_DIR / "fiscal_panel.csv", index=False)
    print(f"\n{'=' * 70}")
    print(f"Saved: {PROCESSED_DIR / 'fiscal_panel.csv'}")
    print(f"  {len(panel):,} obs, {panel['iso3'].nunique()} countries, "
          f"{panel['year'].min()}-{panel['year'].max()}")
    print(f"  govt_debt_gdp: {panel['govt_debt_gdp'].notna().sum():,} obs "
          f"({panel.loc[panel['govt_debt_gdp'].notna(), 'iso3'].nunique()} countries)")
    print(f"  primary_bal_gdp: {panel['primary_bal_gdp'].notna().sum():,} obs")
    print(f"  r_minus_g: {panel['r_minus_g'].notna().sum():,} obs")
    print(f"  Z_1 (demographics): {panel['Z_1'].notna().sum():,} obs")

    # Summary statistics
    print(f"\n{'=' * 70}")
    print("SUMMARY STATISTICS — Key Variables")
    print("=" * 70)
    summary_vars = ['govt_debt_gdp', 'primary_bal_gdp', 'structural_bal_gdp',
                    'r_minus_g', 'r_minus_g_real', 'old_dep', 'Z_1', 'rate_fd',
                    'rgdp_growth', 'kaopen']
    summary_vars = [v for v in summary_vars if v in panel.columns]
    summary = panel[summary_vars].describe().T
    summary['non_missing'] = panel[summary_vars].notna().sum()
    print(summary[['non_missing', 'mean', 'std', 'min', '25%', '50%', '75%', 'max']]
          .to_string(float_format='%.3f'))

    summary.to_csv(TABLE_DIR / "phase1_summary_stats.csv")

    return panel


# Script entry point: run the assembly only when the file is executed
# directly. The returned DataFrame is kept in the module-level name `panel`.
if __name__ == "__main__":
    panel = main()
